from keras.datasets import mnist
import matplotlib.pyplot as plt

# 导入mnist数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 查看训练数据
print("训练集图片尺寸", train_images.shape)
print("训练集标签数目", len(train_labels))
print("训练集标签内容", train_labels)

# 绘制训练标签直方图
plt.hist(train_labels, bins=range(11), rwidth=0.5, align="left")
plt.xticks(range(10))
plt.xlim(-1, 10)
plt.savefig("train_labels.jpg")
plt.close()

# 查看测试数据
print("测试集图片尺寸", test_images.shape)
print("测试集标签数目", len(test_labels))
print("测试集标签内容", test_labels)

# 绘制测试标签直方图
plt.hist(test_labels, bins=range(11), rwidth=0.5, align="left")
plt.xticks(range(10))
plt.xlim(-1, 10)
plt.savefig("test_labels.jpg")
plt.close()
